{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cc1e2529-a2e4-4bf9-8c3b-b3a037e06fd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torch.nn\n",
    "import torch.utils.tensorboard\n",
    "\n",
    "import torchvision.transforms\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "efb32107-c090-43bf-9b5c-44ce2c5b4904",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Pick the first CUDA GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Prepare the CIFAR-10 datasets (downloaded into ../CIFAR10 when missing).\n",
    "# CIFAR10 yields PIL images by default, so ToTensor() converts each image\n",
    "# into a PyTorch tensor.\n",
    "train_data = torchvision.datasets.CIFAR10(\"../CIFAR10\", train=True,\n",
    "                                          transform=torchvision.transforms.ToTensor(), download=True)\n",
    "test_data = torchvision.datasets.CIFAR10(\"../CIFAR10\", train=False, download=True,\n",
    "                                         transform=torchvision.transforms.ToTensor())\n",
    "train_data_size = len(train_data)\n",
    "test_data_size = len(test_data)\n",
    "\n",
    "# Wrap the datasets in DataLoaders that yield mini-batches of 64 samples.\n",
    "# The training loader reshuffles the data every epoch so that SGD does not\n",
    "# see the same batch composition and order each time; the test loader keeps\n",
    "# a fixed order because evaluation does not depend on it.\n",
    "train_dataloader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)\n",
    "test_dataloader = DataLoader(dataset=test_data, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "67cca792-7dd3-4ead-b657-ca5336aa978b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TensorBoard visualisation: scalars logged during training are written under ./logs.\n",
    "writer = torch.utils.tensorboard.SummaryWriter(log_dir=\"./logs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "eb1bb815-35c3-4d5e-8a3b-64e3726088bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CIFAR10Classify01(\n",
       "  (model1): Sequential(\n",
       "    (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (3): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (5): Sigmoid()\n",
       "    (6): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (8): Flatten(start_dim=1, end_dim=-1)\n",
       "    (9): Linear(in_features=1024, out_features=64, bias=True)\n",
       "    (10): Linear(in_features=64, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 搭建神经网络\n",
    "class CIFAR10Classify01(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CIFAR10Classify01, self).__init__()\n",
    "        self.model1 = torch.nn.Sequential(\n",
    "            torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2, ),\n",
    "            torch.nn.MaxPool2d(kernel_size=2),\n",
    "            torch.nn.BatchNorm2d(32),\n",
    "            torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2, ),\n",
    "            torch.nn.MaxPool2d(kernel_size=2),\n",
    "            torch.nn.Sigmoid(),\n",
    "            torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2, ),\n",
    "            torch.nn.MaxPool2d(2),\n",
    "            torch.nn.Flatten(),\n",
    "            # 第一个linear的in_features可以由print(x.shape)得到\n",
    "            torch.nn.Linear(in_features=1024, out_features=64),\n",
    "            torch.nn.Linear(in_features=64, out_features=10)\n",
    "            # 分类问题的最终out_features和类数相同\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.model1(x)\n",
    "        return x\n",
    "    \n",
    "# Instantiate the network and move it to the selected device. The bare name\n",
    "# on the last line makes the notebook display the module summary.\n",
    "myModel = CIFAR10Classify01().to(device)\n",
    "myModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d605e974-7266-4ee3-ab72-954314233065",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss function: cross-entropy, moved to the training device.\n",
    "loss_fn = torch.nn.CrossEntropyLoss().to(device)\n",
    "\n",
    "# Optimizer: stochastic gradient descent (SGD) over all model parameters.\n",
    "learning_rate = 0.01\n",
    "optim = torch.optim.SGD(myModel.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4f6defce-7903-4db5-94e0-90433d4917fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------0------\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "running_mean should contain 3 elements not 32",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Input \u001b[1;32mIn [6]\u001b[0m, in \u001b[0;36m<cell line: 5>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     10\u001b[0m imgs \u001b[38;5;241m=\u001b[39m imgs\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m     11\u001b[0m targets \u001b[38;5;241m=\u001b[39m targets\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m---> 12\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmyModel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimgs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     13\u001b[0m \u001b[38;5;66;03m# 优化器梯度清零\u001b[39;00m\n\u001b[0;32m     14\u001b[0m optim\u001b[38;5;241m.\u001b[39mzero_grad()\n",
      "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1129\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[1;32mIn [4]\u001b[0m, in \u001b[0;36mCIFAR10Classify01.forward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     21\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m---> 22\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     23\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
      "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1129\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\container.py:139\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m    138\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 139\u001b[0m         \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m    140\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
      "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1129\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\batchnorm.py:168\u001b[0m, in \u001b[0;36m_BatchNorm.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    161\u001b[0m     bn_training \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_mean \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_var \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m    163\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m    164\u001b[0m \u001b[38;5;124;03mBuffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be\u001b[39;00m\n\u001b[0;32m    165\u001b[0m \u001b[38;5;124;03mpassed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are\u001b[39;00m\n\u001b[0;32m    166\u001b[0m \u001b[38;5;124;03mused for normalization (i.e. 
in eval mode when buffers are not None).\u001b[39;00m\n\u001b[0;32m    167\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m--> 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_norm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    169\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m    170\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;66;43;03m# If buffers are not to be tracked, ensure that they won't be updated\u001b[39;49;00m\n\u001b[0;32m    171\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrunning_mean\u001b[49m\n\u001b[0;32m    172\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack_running_stats\u001b[49m\n\u001b[0;32m    173\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m    174\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrunning_var\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack_running_stats\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m    175\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    176\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    177\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbn_training\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    178\u001b[0m \u001b[43m    \u001b[49m\u001b[43mexponential_average_factor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    179\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    180\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\functional.py:2438\u001b[0m, in \u001b[0;36mbatch_norm\u001b[1;34m(input, running_mean, running_var, weight, bias, training, momentum, eps)\u001b[0m\n\u001b[0;32m   2435\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training:\n\u001b[0;32m   2436\u001b[0m     _verify_batch_size(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39msize())\n\u001b[1;32m-> 2438\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_norm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   2439\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrunning_mean\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrunning_var\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmomentum\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackends\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcudnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menabled\u001b[49m\n\u001b[0;32m   2440\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[1;31mRuntimeError\u001b[0m: running_mean should contain 3 elements not 32"
     ]
    }
   ],
   "source": [
    "# Training configuration\n",
    "train_step = 0  # optimizer updates performed so far (x-axis of the train scalar)\n",
    "test_step = 0   # test batches evaluated so far (x-axis of the test scalars)\n",
    "epoch = 100\n",
    "for i in range(epoch):\n",
    "    print(\"------{}------\".format(i))\n",
    "\n",
    "    # ---- Train for one pass over the training set ----\n",
    "    train_loss = 0.0\n",
    "    myModel.train()\n",
    "    for imgs, targets in train_dataloader:\n",
    "        imgs = imgs.to(device)\n",
    "        targets = targets.to(device)\n",
    "        output = myModel(imgs)\n",
    "        loss = loss_fn(output, targets)\n",
    "        # Clear stale gradients, backpropagate, then let the optimizer update the weights.\n",
    "        optim.zero_grad()\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        # loss_fn returns the mean over the batch; weight it by the batch size\n",
    "        # so that dividing by the dataset size below gives a per-sample mean.\n",
    "        train_loss += loss.item() * targets.size(0)\n",
    "        train_step += 1\n",
    "    mean_train_loss = train_loss / len(train_data)\n",
    "    writer.add_scalar(\"trainloss\", mean_train_loss, train_step)\n",
    "    print(\"running_loss: \", mean_train_loss)\n",
    "\n",
    "    # ---- Evaluate on the test set after every epoch ----\n",
    "    test_loss = 0.0\n",
    "    correct_count = 0\n",
    "    myModel.eval()\n",
    "    # No gradients are needed during evaluation; nothing is optimized here.\n",
    "    with torch.no_grad():\n",
    "        for imgs, targets in test_dataloader:\n",
    "            imgs = imgs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            output = myModel(imgs)\n",
    "            loss = loss_fn(output, targets)\n",
    "            test_loss += loss.item() * targets.size(0)\n",
    "            # .item() turns the count into a plain Python int, so the accuracy\n",
    "            # below is a float rather than a tensor living on the device.\n",
    "            correct_count += (output.argmax(1) == targets).sum().item()\n",
    "            test_step += 1\n",
    "    mean_test_loss = test_loss / len(test_data)\n",
    "    test_accuracy = correct_count / len(test_data)\n",
    "    writer.add_scalar(\"testloss\", mean_test_loss, test_step)\n",
    "    writer.add_scalar(\"total_accuracy\", test_accuracy, test_step)\n",
    "    print(\"test_loss: \", mean_test_loss)\n",
    "    print(\"total_accuracy: \", test_accuracy)\n",
    "    # Checkpoint after every epoch. This pickles the whole module object; the\n",
    "    # format is kept unchanged so existing loading code keeps working.\n",
    "    torch.save(myModel, \"./myModel_{}.pth\".format(i))\n",
    "\n",
    "writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16cef185-cf09-4eab-b5a4-643ae14148fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load (or reload, if already loaded) the `tensorboard` IPython extension.\n",
    "%reload_ext tensorboard\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3280638-14cd-47f8-9717-c6eb86613ca5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# List the TensorBoard instances known to the notebook helper, then start\n",
    "# one that reads the event files written to ./logs.\n",
    "from tensorboard import notebook\n",
    "\n",
    "notebook.list()\n",
    "# NOTE(review): port 709 is below 1024 and may require elevated privileges on\n",
    "# Unix-like systems - confirm this port is intended.\n",
    "tensorboard_args = \"--logdir ./logs --port=709\"\n",
    "notebook.start(tensorboard_args)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
